//
// Copyright 2023 The GUAC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cmd

import (
	"context"
	"fmt"
	"net/http"
	"os"
	"strings"

	"github.com/guacsec/guac/internal/testing/ptrfrom"
	model "github.com/guacsec/guac/pkg/assembler/clients/generated"
	"github.com/guacsec/guac/pkg/assembler/helpers"
	"github.com/guacsec/guac/pkg/cli"
	"github.com/guacsec/guac/pkg/guacanalytics"
	"github.com/guacsec/guac/pkg/logging"

	"github.com/Khan/genqlient/graphql"
	"github.com/jedib0t/go-pretty/v6/table"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"
)

// Identifiers for the kinds of subject the query commands work with.
// artifactType, uriType and purlType are the <type> positional values accepted
// by the vuln command (see validateQueryVulnFlags). guacType and noVulnType are
// not referenced in this file — presumably used by sibling commands in this
// package (TODO confirm).
const (
	guacType     string = "guac"
	noVulnType   string = "novuln"
	purlType     string = "purl"
	uriType      string = "uri"
	artifactType string = "artifact"
)

// queryOptions holds the settings for one run of the vuln query command, as
// assembled by validateQueryVulnFlags from flags and positional arguments.
type queryOptions struct {
	graphqlEndpoint string // GraphQL server address (gql-addr flag)
	headerFile      string // header file handed to cli.HTTPHeaderTransport (header-file flag)
	searchString    string // the <input> positional argument: purl, "algorithm:digest" or SBOM URI
	isPurl          bool   // true when searchString parses as a purl
	vulnerabilityID string // when non-empty, restricts the query to this vulnerability (vuln-id flag)
	depth           int    // dependency search depth limit (search-depth flag)
	pathsToReturn   int    // cap on reported vulnerability paths, 0 for no cap (num-path flag)
	inputType       string // the <type> positional argument: artifactType, uriType or purlType
}

// queryVulnCmd implements the "vuln" subcommand: it reports whether a package,
// SBOM URI or artifact is affected by vulnerabilities. When --vuln-id is set,
// the query is restricted to that single vulnerability.
var queryVulnCmd = &cobra.Command{
	Use:   "vuln [flags] <type> <input>",
	Short: "query if a package is affected by the specified vulnerability",
	Long: `The vuln command allows you to query whether a specific package, SBOM URI, or artifact is affected by a given vulnerability.

Positional Arguments:
  <type>    Specify the input type: 'artifact', 'uri', or 'purl'
  <input>   The corresponding input based on the specified type`,
	Run: func(cmd *cobra.Command, args []string) {
		ctx := logging.WithLogger(context.Background())

		// Ensure exactly two arguments are provided: <type> and <input>
		if len(args) != 2 {
			fmt.Println("error: Exactly two arguments must be provided: <type> and <input>.")
			_ = cmd.Help()
			os.Exit(1)
		}

		// Extract and validate other flags
		opts, err := validateQueryVulnFlags(
			viper.GetString("gql-addr"),
			viper.GetString("header-file"),
			viper.GetString("vuln-id"),
			viper.GetInt("search-depth"),
			viper.GetInt("num-path"),
			args,
		)
		if err != nil {
			fmt.Printf("Unable to validate flags: %v\n", err)
			_ = cmd.Help()
			os.Exit(1)
		}

		httpClient := http.Client{Transport: cli.HTTPHeaderTransport(ctx, opts.headerFile, http.DefaultTransport)}
		gqlclient := graphql.NewClient(opts.graphqlEndpoint, &httpClient)

		// The writer is filled and rendered by the print functions below.
		t := table.NewWriter()
		t.AppendHeader(rowHeader)

		// A vulnerability ID narrows the search to that vulnerability's nodes;
		// otherwise every vulnerability reachable via hasSBOM is reported.
		if opts.vulnerabilityID != "" {
			printVulnInfoByVulnID(ctx, gqlclient, t, opts)
		} else {
			printVulnInfo(ctx, gqlclient, t, opts)
		}
	},
}

// getPkgResponseFromPurl parses purl into a package filter (type, namespace,
// name, version, subpath and qualifiers) and queries the GraphQL endpoint for
// matching packages.
//
// It returns an error when the purl cannot be parsed, when the query fails, or
// when the query does not match exactly one package, so on success callers may
// index pkgResponse.Packages[0].
func getPkgResponseFromPurl(ctx context.Context, gqlclient graphql.Client, purl string) (*model.PackagesResponse, error) {
	pkgInput, err := helpers.PurlToPkg(purl)
	if err != nil {
		return nil, fmt.Errorf("failed to parse PURL: %w", err)
	}

	pkgQualifierFilter := make([]model.PackageQualifierSpec, 0, len(pkgInput.Qualifiers))
	for _, qualifier := range pkgInput.Qualifiers {
		// Copy the loop variable so every &qualifier.Value points at its own
		// string; to prevent https://github.com/golang/go/discussions/56010
		qualifier := qualifier
		pkgQualifierFilter = append(pkgQualifierFilter, model.PackageQualifierSpec{
			Key:   qualifier.Key,
			Value: &qualifier.Value,
		})
	}

	pkgFilter := model.PkgSpec{
		Type:       &pkgInput.Type,
		Namespace:  pkgInput.Namespace,
		Name:       &pkgInput.Name,
		Version:    pkgInput.Version,
		Subpath:    pkgInput.Subpath,
		Qualifiers: pkgQualifierFilter,
	}
	pkgResponse, err := model.Packages(ctx, gqlclient, pkgFilter)
	if err != nil {
		return nil, fmt.Errorf("error querying for package: %w", err)
	}
	if len(pkgResponse.Packages) != 1 {
		return nil, fmt.Errorf("failed to locate package based on purl %q", purl)
	}
	return pkgResponse, nil
}

// printVulnInfo looks for vulnerability paths reachable from the subject in
// opts through hasSBOM and prints them as a table followed by a visualizer URL.
//
// Artifacts are searched directly first. If that yields nothing, the search is
// retried through the package connected to the artifact. Other inputs are
// searched as packages. For a purl that yields nothing, the search is retried
// through the connected artifact. Any failing query ends the process via
// logger.Fatalf.
func printVulnInfo(ctx context.Context, gqlclient graphql.Client, t table.Writer, opts queryOptions) {
	logger := logging.FromContext(ctx)

	var (
		foundPaths []string
		foundRows  []table.Row
		err        error
	)

	if opts.inputType == artifactType {
		// If it's an artifact, search for SBOMs via artifact
		foundPaths, foundRows, err = guacanalytics.SearchForSBOMViaArtifact(ctx, gqlclient, opts.searchString, opts.depth)
		if err != nil {
			logger.Fatalf("error searching via hasSBOM for artifact: %v", err)
		}
		if len(foundPaths) == 0 {
			foundPaths, foundRows, err = findConnectedPkgAndSearchViaPkg(ctx, gqlclient, opts)
			if err != nil {
				logger.Fatalf("error finding purl connected to artifact and searching via package: %v", err)
			}
		}
	} else {
		// Otherwise, search for SBOMs via package
		foundPaths, foundRows, err = guacanalytics.SearchForSBOMViaPkg(ctx, gqlclient, opts.searchString, opts.depth, opts.isPurl)
		if err != nil {
			logger.Fatalf("error searching via hasSBOM for package: %v", err)
		}
		if opts.inputType == purlType && len(foundPaths) == 0 {
			foundPaths, foundRows, err = findConnectedArtAndSearchViaArt(ctx, gqlclient, opts)
			if err != nil {
				logger.Fatalf("error finding artifact connected to package and searching via artifact: %v", err)
			}
		}
	}

	if len(foundPaths) == 0 {
		fmt.Printf("No path to vulnerabilities found!\n")
		return
	}

	t.AppendRows(foundRows)
	fmt.Println(t.Render())
	fmt.Printf("Visualizer URL: http://localhost:3000/?path=%v\n", strings.Join(removeDuplicateValuesFromPath(foundPaths), `,`))
}

// findConnectedArtAndSearchViaArt finds the artifact attached to the packages with the given purl.
// After finding the artifact, the graph is searched via that artifact.
//
// opts.searchString must be a purl. Only the first IsOccurrence linking the
// package to an artifact is followed. When the package has no occurrence,
// nil paths, nil rows and a nil error are returned.
func findConnectedArtAndSearchViaArt(ctx context.Context, gqlclient graphql.Client, opts queryOptions) ([]string, []table.Row, error) {
	// convert the package to an artifact via an isOccurrence.
	pkgSpec, err := helpers.PurlToPkgFilter(opts.searchString)
	if err != nil {
		return nil, nil, fmt.Errorf("error converting purl to pkg: %w", err)
	}

	occ, err := model.Occurrences(ctx, gqlclient, model.IsOccurrenceSpec{
		Subject: &model.PackageOrSourceSpec{
			Package: &pkgSpec,
		},
	})
	if err != nil {
		return nil, nil, fmt.Errorf("error getting occurrences for package: %w", err)
	}
	if len(occ.IsOccurrence) == 0 {
		return nil, nil, nil
	}

	art := occ.IsOccurrence[0].Artifact
	depVulnPaths, depVulnTableRows, err := guacanalytics.SearchForSBOMViaArtifact(ctx, gqlclient, art.Algorithm+":"+art.Digest, opts.depth)
	if err != nil {
		return nil, nil, fmt.Errorf("error searching via hasSBOM for artifact: %w", err)
	}
	return depVulnPaths, depVulnTableRows, nil
}

// findConnectedPkgAndSearchViaPkg finds the pkg attached to the artifact.
// After finding the pkg, the graph is searched via that package.
//
// opts.searchString must be an "algorithm:digest" artifact. Only the first
// IsOccurrence of the artifact is followed, and its subject must be a package.
// When the artifact has no occurrence at all, nil paths, nil rows and a nil
// error are returned.
func findConnectedPkgAndSearchViaPkg(ctx context.Context, gqlclient graphql.Client, opts queryOptions) ([]string, []table.Row, error) {
	occurrence, err := searchArtToPkg(ctx, gqlclient, opts.searchString)
	if err != nil {
		return nil, nil, fmt.Errorf("error searching for package via artifact: %w", err)
	}
	// An artifact without any IsOccurrence has no connected package; report
	// "nothing found" instead of indexing into an empty slice.
	if len(occurrence.IsOccurrence) == 0 {
		return nil, nil, nil
	}

	pkg, ok := occurrence.IsOccurrence[0].Subject.(*model.AllIsOccurrencesTreeSubjectPackage)
	if !ok {
		return nil, nil, fmt.Errorf("error converting isOccurrence to package subject")
	}

	purl := pkg.Namespaces[0].Names[0].Versions[0].Purl
	depVulnPaths, depVulnTableRows, err := guacanalytics.SearchForSBOMViaPkg(ctx, gqlclient, purl, opts.depth, true)
	if err != nil {
		return nil, nil, fmt.Errorf("error searching via hasSBOM for package: %w", err)
	}
	return depVulnPaths, depVulnTableRows, nil
}

// searchArtToPkg parses searchString as "algorithm:digest" (both parts are
// lower-cased before querying) and returns the IsOccurrence records whose
// artifact matches.
//
// An error is returned when searchString is not exactly two non-empty,
// colon-separated parts, or when the query fails. A successful response may
// still contain no occurrences, so callers must check len(o.IsOccurrence)
// before indexing it.
func searchArtToPkg(ctx context.Context, gqlclient graphql.Client, searchString string) (*model.OccurrencesResponse, error) {
	parts := strings.Split(searchString, ":")
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return nil, fmt.Errorf("failed to parse artifact %q. Needs to be in algorithm:digest form", searchString)
	}
	artifactFilter := model.ArtifactSpec{
		Algorithm: ptrfrom.String(strings.ToLower(parts[0])),
		Digest:    ptrfrom.String(strings.ToLower(parts[1])),
	}

	o, err := model.Occurrences(ctx, gqlclient, model.IsOccurrenceSpec{
		Artifact: &artifactFilter,
	})
	if err != nil {
		return nil, fmt.Errorf("error querying for occurrences: %w", err)
	}

	return o, nil
}

// printVulnInfoByVulnID checks whether the subject in opts is affected by the
// vulnerability opts.vulnerabilityID. The subject is a purl, an
// "algorithm:digest" artifact, or an SBOM URI. Matching paths are printed as a
// table followed by a visualizer URL.
//
// The subject is first resolved to a package version ID. For an artifact, or
// for the artifact subject of the hasSBOM found for a URI, this goes through
// the artifact's first IsOccurrence. If that occurrence is missing or its
// subject is not a package, "No path" is reported. Query failures, a URI with
// no hasSBOM, and a hasSBOM whose subject is neither a package nor an artifact
// all end the process via logger.Fatalf.
func printVulnInfoByVulnID(ctx context.Context, gqlclient graphql.Client, t table.Writer, opts queryOptions) {
	logger := logging.FromContext(ctx)

	vulnResponse, err := model.Vulnerabilities(ctx, gqlclient, model.VulnerabilitySpec{VulnerabilityID: &opts.vulnerabilityID})
	if err != nil {
		logger.Fatalf("error querying for vulnerabilities: %v", err)
	}

	if len(vulnResponse.Vulnerabilities) == 0 {
		fmt.Printf("Failed to identify vulnerability. Please ensure certifier has run by running guacone certifier osv\n")
		return
	}

	// topPkgVersionID is the package version from which the dependency search
	// starts. It stays empty when the input cannot be linked to a package.
	var topPkgVersionID string

	switch opts.inputType {
	case purlType:
		pkgResponse, err := getPkgResponseFromPurl(ctx, gqlclient, opts.searchString)
		if err != nil {
			logger.Fatalf("getPkgResponseFromPurl - error: %v", err)
		}
		topPkgVersionID = pkgResponse.Packages[0].Namespaces[0].Names[0].Versions[0].Id
	case artifactType:
		// searchArtToPkg validates the "algorithm:digest" form itself; splitting
		// the input here first panicked on input without a colon.
		topPkgVersionID, err = pkgVersionIDFromArtifactOccurrence(ctx, gqlclient, opts.searchString)
		if err != nil {
			logger.Fatalf("error searching for package via artifact: %v", err)
		}
	case uriType:
		foundHasSBOM, err := model.HasSBOMs(ctx, gqlclient, model.HasSBOMSpec{Uri: &opts.searchString})
		if err != nil {
			logger.Fatalf("failed getting hasSBOM via URI: %s with error: %v", opts.searchString, err)
		}
		if len(foundHasSBOM.HasSBOM) == 0 {
			logger.Fatalf("failed to locate hasSBOM via URI: %s", opts.searchString)
		}
		switch subject := foundHasSBOM.HasSBOM[0].Subject.(type) {
		case *model.AllHasSBOMTreeSubjectPackage:
			topPkgVersionID = subject.Namespaces[0].Names[0].Versions[0].Id
		case *model.AllHasSBOMTreeSubjectArtifact:
			topPkgVersionID, err = pkgVersionIDFromArtifactOccurrence(ctx, gqlclient, subject.Algorithm+":"+subject.Digest)
			if err != nil {
				logger.Fatalf("error searching for package via artifact: %v", err)
			}
		default:
			logger.Fatalf("located hasSBOM does not have a subject that is a package or artifact")
		}
	}

	var path []string
	var tableRows []table.Row
	if topPkgVersionID != "" {
		path, tableRows, err = queryVulnsViaVulnNodeNeighbors(ctx, gqlclient, topPkgVersionID, vulnResponse.Vulnerabilities, opts.depth, opts.pathsToReturn)
		if err != nil {
			logger.Fatalf("error querying neighbor: %v", err)
		}
	}

	if len(path) > 0 {
		t.AppendRows(tableRows)
		fmt.Println(t.Render())
		fmt.Printf("Visualizer url: http://localhost:3000/?path=%v\n", strings.Join(removeDuplicateValuesFromPath(path), `,`))
	} else {
		fmt.Printf("No path to vulnerability ID found!\n")
	}
}

// pkgVersionIDFromArtifactOccurrence returns the package version ID linked to
// the "algorithm:digest" artifact by its first IsOccurrence. It returns "" with
// a nil error when the artifact has no occurrence or when that occurrence's
// subject is not a package. Errors from searchArtToPkg are returned unchanged.
func pkgVersionIDFromArtifactOccurrence(ctx context.Context, gqlclient graphql.Client, artifact string) (string, error) {
	occur, err := searchArtToPkg(ctx, gqlclient, artifact)
	if err != nil {
		return "", err
	}
	if len(occur.IsOccurrence) == 0 {
		return "", nil
	}
	subjectPackage, ok := occur.IsOccurrence[0].Subject.(*model.AllIsOccurrencesTreeSubjectPackage)
	if !ok {
		return "", nil
	}
	return subjectPackage.Namespaces[0].Names[0].Versions[0].Id, nil
}

// vexSubjectIds lists the graph node IDs that form the subject of a VEX
// statement. For an artifact subject this is the artifact ID alone. For a
// package subject it is the package type, namespace, name and version IDs, in
// that order. Any other subject, including a nil one, yields an empty,
// non-nil slice.
func vexSubjectIds(s model.AllCertifyVEXStatementSubjectPackageOrArtifact) []string {
	if artifact, isArtifact := s.(*model.AllCertifyVEXStatementSubjectArtifact); isArtifact {
		return []string{artifact.Id}
	}

	pkg, isPackage := s.(*model.AllCertifyVEXStatementSubjectPackage)
	if !isPackage {
		return []string{}
	}

	namespace := pkg.Namespaces[0]
	name := namespace.Names[0]
	return []string{pkg.Id, namespace.Id, name.Id, name.Versions[0].Id}
}

// queryVulnsViaVulnNodeNeighbors gathers the CertifyVuln and
// CertifyVEXStatement neighbors of every vulnerability ID node in
// vulnerabilitiesResponses and reports those relevant to topPkgVersionID.
//
// A CertifyVuln is reported when its vulnerable package version is
// topPkgVersionID itself, or when searchDependencyPackagesReverse (bounded by
// depth) finds a dependency path from that package version up to
// topPkgVersionID. Every CertifyVEXStatement found is reported.
//
// pathsToReturn caps how many CertifyVuln paths are collected (0 = no cap).
// The function returns the flattened node IDs of all reported paths and the
// matching table rows. It returns an error when a query fails or when no
// CertifyVuln/CertifyVEXStatement neighbor exists at all.
func queryVulnsViaVulnNodeNeighbors(ctx context.Context, gqlclient graphql.Client, topPkgVersionID string, vulnerabilitiesResponses []model.VulnerabilitiesVulnerabilitiesVulnerability, depth int, pathsToReturn int) ([]string, []table.Row, error) {
	var path []string
	var tableRows []table.Row
	var neighbors []model.NeighborsNeighborsNode

	edgeTypes := []model.Edge{model.EdgeVulnerabilityCertifyVuln, model.EdgeVulnerabilityCertifyVexStatement}
	for _, vulnerabilitiesResponse := range vulnerabilitiesResponses {
		for _, vulnerabilityNodeID := range vulnerabilitiesResponse.VulnerabilityIDs {
			vulnNodeNeighborResponse, err := model.Neighbors(ctx, gqlclient, vulnerabilityNodeID.Id, edgeTypes)
			if err != nil {
				return nil, nil, fmt.Errorf("error querying neighbor for vulnerability: %w", err)
			}
			for _, neighbor := range vulnNodeNeighborResponse.Neighbors {
				neighbors = append(neighbors, neighbor)
			}
		}
	}

	// certifyNodeFound records whether any CertifyVuln or VEX node exists, which
	// distinguishes "not affected" from "certifier never ran".
	certifyNodeFound := false
	numberOfPaths := 0
	for _, neighbor := range neighbors {
		switch node := neighbor.(type) {
		case *model.NeighborsNeighborsCertifyVuln:
			certifyNodeFound = true
			vulnPkgVersionID := node.Package.Namespaces[0].Names[0].Versions[0].Id

			// The reverse dependency search returns an empty path when it starts
			// at topPkgVersionID, so a directly vulnerable top package must be
			// handled here or it would never be reported.
			var pkgPath []string
			if vulnPkgVersionID != topPkgVersionID {
				var err error
				pkgPath, err = searchDependencyPackagesReverse(ctx, gqlclient, topPkgVersionID, vulnPkgVersionID, depth)
				if err != nil {
					return nil, nil, fmt.Errorf("error searching dependency packages match: %w", err)
				}
				if len(pkgPath) == 0 {
					// topPkgVersionID does not depend on the vulnerable package.
					continue
				}
			}

			vulnID := node.Vulnerability.VulnerabilityIDs[0]
			tableRows = append(tableRows, table.Row{certifyVulnStr, node.Id, "vulnerability ID: " + vulnID.VulnerabilityID})
			path = append(path, node.Vulnerability.Id, vulnID.Id, node.Id,
				vulnPkgVersionID,
				node.Package.Namespaces[0].Names[0].Id, node.Package.Namespaces[0].Id,
				node.Package.Id)
			path = append(path, pkgPath...)
			numberOfPaths++

			if pathsToReturn != 0 && numberOfPaths == pathsToReturn {
				// Return the rows gathered so far as well; dropping them left the
				// printed table empty while paths were still reported.
				return path, tableRows, nil
			}
		case *model.NeighborsNeighborsCertifyVEXStatement:
			certifyNodeFound = true
			for _, vuln := range node.Vulnerability.VulnerabilityIDs {
				tableRows = append(tableRows, table.Row{vexLinkStr, node.Id, "vulnerability ID: " + vuln.VulnerabilityID + ", Vex Status: " + string(node.Status) + ", Subject: " + guacanalytics.VexSubjectString(node.Subject)})
				path = append(path, node.Id, vuln.Id)
			}
			path = append(path, vexSubjectIds(node.Subject)...)
		}
	}
	if !certifyNodeFound {
		return nil, nil, fmt.Errorf("error certify vulnerability node not found, incomplete data. Please ensure certifier has run by running guacone certifier osv")
	}
	return path, tableRows, nil
}

// searchDependencyPackagesReverse walks IsDependency edges breadth-first,
// starting at the package version searchPkgID and moving towards the package
// versions that depend on it.
//
// With a non-empty topPkgID, the walk stops once topPkgID is dequeued. The
// function then returns the node IDs of the IsDependency chain from topPkgID
// back down to searchPkgID. If topPkgID is never reached, it returns nil. It
// also returns nil when topPkgID equals searchPkgID, because no edge is crossed.
// With an empty topPkgID, it returns the IDs of every IsDependency edge that
// was discovered, most recently discovered first.
//
// maxLength limits how many hops are explored; 0 disables the limit.
func searchDependencyPackagesReverse(ctx context.Context, gqlclient graphql.Client, topPkgID string, searchPkgID string, maxLength int) ([]string, error) {
	// bfsEntry records how a package version was reached.
	type bfsEntry struct {
		expanded     bool                                  // neighbors already enqueued
		parent       string                                // node it was first discovered from
		isDependency *model.NeighborsNeighborsIsDependency // edge used to discover it
		depth        int                                   // hops from searchPkgID
	}

	// edgeIDs flattens an IsDependency edge into: the edge ID, the dependency
	// package's version (if any), name, namespace and type IDs, then the
	// dependent package's version, name, namespace and type IDs.
	edgeIDs := func(dep *model.NeighborsNeighborsIsDependency) []string {
		ids := []string{dep.Id}
		dependencyName := dep.DependencyPackage.Namespaces[0].Names[0]
		if len(dependencyName.Versions) > 0 {
			ids = append(ids, dependencyName.Versions[0].Id)
		}
		return append(ids,
			dependencyName.Id,
			dep.DependencyPackage.Namespaces[0].Id,
			dep.DependencyPackage.Id,
			dep.Package.Namespaces[0].Names[0].Versions[0].Id,
			dep.Package.Namespaces[0].Names[0].Id,
			dep.Package.Namespaces[0].Id,
			dep.Package.Id,
		)
	}

	visited := map[string]bfsEntry{searchPkgID: {}}
	queue := []string{searchPkgID}
	discovered := []string{searchPkgID}

	reachedTop := false
	for len(queue) > 0 {
		current := queue[0]
		queue = queue[1:]
		entry := visited[current]

		if topPkgID != "" && current == topPkgID {
			reachedTop = true
			break
		}
		if maxLength != 0 && entry.depth >= maxLength {
			break
		}

		resp, err := model.Neighbors(ctx, gqlclient, current, []model.Edge{model.EdgePackageIsDependency})
		if err != nil {
			return nil, fmt.Errorf("failed getting package parent:%v", err)
		}
		for _, n := range resp.Neighbors {
			dep, ok := n.(*model.NeighborsNeighborsIsDependency)
			if !ok {
				continue
			}
			dependentID := dep.Package.Namespaces[0].Names[0].Versions[0].Id
			if dependentID == current {
				// Edge where current is the dependent, not the dependency.
				continue
			}
			next, seen := visited[dependentID]
			if !seen {
				next = bfsEntry{parent: current, isDependency: dep, depth: entry.depth + 1}
				visited[dependentID] = next
			}
			if !next.expanded {
				queue = append(queue, dependentID)
				discovered = append(discovered, dependentID)
			}
		}
		entry.expanded = true
		visited[current] = entry
	}

	if topPkgID != "" {
		// not found so return nil
		if !reachedTop {
			return nil, nil
		}
		var chain []string
		for node := topPkgID; node != searchPkgID; node = visited[node].parent {
			chain = append(chain, edgeIDs(visited[node].isDependency)...)
		}
		return chain, nil
	}

	var all []string
	for i := len(discovered) - 1; i >= 0; i-- {
		if dep := visited[discovered[i]].isDependency; dep != nil {
			all = append(all, edgeIDs(dep)...)
		}
	}
	return all, nil
}

// removeDuplicateValuesFromPath returns the entries of path with later repeats
// dropped. The first occurrence of each value is kept, in its original order.
// The result is nil when path is empty.
func removeDuplicateValuesFromPath(path []string) []string {
	seen := make(map[string]struct{}, len(path))
	var deduped []string
	for _, id := range path {
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		deduped = append(deduped, id)
	}
	return deduped
}

// validateQueryVulnFlags builds the queryOptions for the vuln command from the
// flag values and the positional arguments.
//
// args must hold exactly one input type (artifact, uri or purl, matched
// case-insensitively) and one non-empty input value, in either order.
// opts.isPurl records whether the input parses as a purl. Problems are
// returned as errors rather than exiting, so the caller decides how to report
// them. On error, the returned queryOptions is the zero value.
func validateQueryVulnFlags(graphqlEndpoint, headerFile, vulnID string, depth, path int, args []string) (queryOptions, error) {
	if len(args) == 0 {
		return queryOptions{}, fmt.Errorf("expected subject input to be purl, artifact or SBOM URI")
	}

	validTypes := []string{artifactType, uriType, purlType}

	// The type and the input may appear in either order, so classify each
	// argument by whether it names a valid type.
	var typeArg, inputArg string
	for _, arg := range args {
		lowered := strings.ToLower(arg)
		if contains(validTypes, lowered) {
			if typeArg != "" {
				return queryOptions{}, fmt.Errorf("multiple types provided, please specify only one type")
			}
			typeArg = lowered
			continue
		}
		if inputArg != "" {
			return queryOptions{}, fmt.Errorf("multiple inputs provided, please specify only one input")
		}
		inputArg = arg
	}

	if typeArg == "" {
		return queryOptions{}, fmt.Errorf("input type not specified or invalid, valid types are: %v", validTypes)
	}
	if inputArg == "" {
		return queryOptions{}, fmt.Errorf("no input provided for type %q", typeArg)
	}

	_, purlErr := helpers.PurlToPkg(inputArg)
	return queryOptions{
		graphqlEndpoint: graphqlEndpoint,
		headerFile:      headerFile,
		searchString:    inputArg,
		isPurl:          purlErr == nil,
		vulnerabilityID: vulnID,
		depth:           depth,
		pathsToReturn:   path,
		inputType:       typeArg,
	}, nil
}

// contains reports whether item equals any element of slice.
// A nil or empty slice never contains anything.
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] != item {
			continue
		}
		return true
	}
	return false
}

// init builds the vuln command's flags, binds them to viper so that Run can
// read them with viper.Get*, and registers the command under queryCmd.
// Failures are printed to stderr and exit the process with status 1.
func init() {
	set, err := cli.BuildFlags([]string{"vuln-id", "search-depth", "num-path"})
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to setup flag: %v\n", err)
		os.Exit(1)
	}
	queryVulnCmd.Flags().AddFlagSet(set)
	if err := viper.BindPFlags(queryVulnCmd.Flags()); err != nil {
		fmt.Fprintf(os.Stderr, "failed to bind flags: %v\n", err)
		os.Exit(1)
	}

	queryCmd.AddCommand(queryVulnCmd)
}
